package com.luis.toolsuite.isolationforest;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;


/**
 * Isolation Forest anomaly detector.
 *
 * <p>{@link #train} builds an ensemble of {@link ITree}s, each grown on a random sub-sample of the
 * training data. {@link #computeAnomalyScore} maps a sample's average path length across those
 * trees to a score {@code 2^(-E[h(x)] / c(n))}: the shorter the average path, the closer the score
 * is to 1.
 */
public class IForest implements Serializable {

    private static final long serialVersionUID = -9146670646162139090L;

    /** Euler-Mascheroni constant, used to approximate harmonic numbers in {@link #computeCn}. */
    private static final double EULER_GAMMA = 0.5772156649;

    /**
     * Sub-sample size used by the currently trained model. {@link #train} clamps it to the number
     * of training samples, and {@link #computeAnomalyScore} normalises path lengths with it.
     */
    private int subSampleSize;

    /**
     * Sub-sample size requested at construction time. It is kept separately so that training on a
     * small data set does not permanently shrink the size used by later calls to {@link #train}.
     * It is 0 for instances deserialized from a stream written before this field existed.
     */
    private int requestedSubSampleSize;

    /** The trees that make up the forest. */
    private List<ITree> iTreeList;

    /** Creates a forest that draws 256 samples per tree. */
    public IForest() {
        this(256);
    }

    /**
     * Creates a forest that draws {@code subSample} samples per tree.
     *
     * @param subSample sub-sample size per tree; it is validated (must be at least 1) when
     *     {@link #train} runs
     */
    public IForest(int subSample) {
        this.subSampleSize = subSample;
        this.requestedSubSampleSize = subSample;
        this.iTreeList = new ArrayList<>();
    }

    /**
     * Trains the model, replacing any previously trained trees.
     *
     * <p>Each tree is grown on {@code min(subSampleSize, samples.length)} rows chosen by
     * {@code IFUtils.getUniqueIndex}. Its height is limited to {@code ceil(log2(size))}. The
     * per-tree seeds are derived from the fixed {@code IFUtils.IF_SEED}.
     *
     * @param samples training samples, one row per sample; rows are passed to
     *     {@code ITree.createITree} by reference, not copied
     * @param subTreeNum number of trees to build
     * @throws IllegalArgumentException if the sub-sample size is below 1, {@code subTreeNum} is not
     *     positive, or {@code samples} is null or empty
     */
    public void train(double[][] samples, int subTreeNum) {
        // Instances deserialized from older streams have requestedSubSampleSize == 0;
        // fall back to the stored size for them.
        int requested = requestedSubSampleSize > 0 ? requestedSubSampleSize : subSampleSize;
        if (requested < 1 || subTreeNum <= 0 || samples == null || samples.length == 0) {
            throw new IllegalArgumentException("param illegal for iForest train");
        }

        int size = Math.min(requested, samples.length);
        int limitHeight = ceilLog2(size);

        // Derive one seed per tree up front so that each tree's sub-sample is reproducible.
        int baseSeed = IFUtils.IF_SEED;
        Random random = new Random(baseSeed);
        int[] seeds = new int[subTreeNum];
        for (int i = 0; i < subTreeNum; i++) {
            seeds[i] = random.nextInt(Integer.MAX_VALUE);
        }

        Random treeRandom = new Random(baseSeed);
        List<ITree> trees = new ArrayList<>(subTreeNum);
        for (int i = 0; i < subTreeNum; i++) {
            random.setSeed(seeds[i]);
            Integer[] indexes = IFUtils.getUniqueIndex(random, size, samples.length);
            // Only row references are stored, so rows are not pre-allocated.
            double[][] subSample = new double[size][];
            for (int j = 0; j < size; j++) {
                subSample[j] = samples[indexes[j]];
            }
            trees.add(ITree.createITree(subSample, 0, limitHeight, treeRandom));
        }

        // Commit the new model only after every tree was built successfully. It replaces the old
        // trees instead of being mixed with them.
        this.subSampleSize = size;
        this.iTreeList = trees;
    }

    /**
     * Computes the anomaly score of one sample.
     *
     * @param sample the sample to score; it must contain every feature index used by the trees
     * @return {@code 2^(-E[h(x)] / c(n))}, where {@code E[h(x)]} is the average path length over
     *     all trees and {@code n} is the trained sub-sample size. Returns 0.5 when {@code c(n)} is
     *     0, i.e. when the forest was trained with a sub-sample of a single point.
     * @throws IllegalArgumentException if the model is untrained, or {@code sample} is null or empty
     */
    public double computeAnomalyScore(double[] sample) {
        if (iTreeList.isEmpty()) {
            throw new IllegalArgumentException("iTreeList is empty");
        }
        if (sample == null || sample.length == 0) {
            throw new IllegalArgumentException("Sample is null or empty");
        }
        double cn = computeCn(subSampleSize);
        if (cn == 0) {
            // With c(n) == 0 the normalisation below would divide by zero (NaN or 0);
            // report the neutral score instead.
            return 0.5;
        }
        double totalPathLength = 0;
        for (ITree iTree : iTreeList) {
            totalPathLength += computePathLength(sample, iTree, 0);
        }
        double averagePathLength = totalPathLength / iTreeList.size();
        return Math.pow(2, -(averagePathLength / cn));
    }

    /**
     * Computes the path length of {@code sample} in {@code iTree}.
     *
     * <p>At a leaf, the path length is the depth reached plus {@code c(leafNodes)}. That term
     * estimates the extra depth an unbuilt subtree over the leaf's points would have added.
     *
     * @param sample the sample to route through the tree
     * @param iTree the (sub)tree to descend
     * @param height depth of {@code iTree} within the full tree
     * @return the adjusted path length
     */
    private double computePathLength(double[] sample, ITree iTree, int height) {
        if (iTree.lTree == null && iTree.rTree == null) {
            return height + computeCn(iTree.leafNodes);
        }
        double attrValue = sample[iTree.attrIndex];
        ITree next = attrValue <= iTree.attrValue ? iTree.lTree : iTree.rTree;
        return computePathLength(sample, next, height + 1);
    }

    /**
     * Computes {@code c(n) = 2H(n-1) - 2(n-1)/n}, the normalisation term from the Isolation Forest
     * paper. {@code H(i)} is approximated by {@code ln(i) + gamma}.
     *
     * @param n number of points
     * @return 0 for {@code n <= 1}, exactly 1 for {@code n == 2}, otherwise the approximation
     */
    private double computeCn(double n) {
        if (n <= 1) {
            return 0;
        }
        if (n == 2) {
            // H(1) = 1 exactly, so c(2) = 1; ln(1) + gamma would give about 0.15 instead.
            return 1;
        }
        return 2 * (Math.log(n - 1) + EULER_GAMMA) - 2 * ((n - 1) / n);
    }

    /** Returns {@code ceil(log2(n))} for {@code n >= 1}, computed exactly with integer arithmetic. */
    private static int ceilLog2(int n) {
        return 32 - Integer.numberOfLeadingZeros(n - 1);
    }
}
